import torch
import torch.nn as nn
from torchvision import models

def get_model(model_name: str = 'resnet-50', classes: int = 50) -> nn.Module:
    """Build a randomly initialised classifier with a fresh output head.

    Args:
        model_name: Architecture identifier. Only ``'resnet-50'`` is
            currently supported.
        classes: Number of output features for the replacement ``fc`` layer
            (i.e. the number of target classes).

    Returns:
        A torchvision model built with ``pretrained=False`` whose final
        fully connected layer (``model.fc``) maps to ``classes`` outputs.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    # Map public names to torchvision constructors; extend here to add models.
    builders = {
        'resnet-50': models.resnet50,
    }
    try:
        build = builders[model_name]
    except KeyError:
        supported = ', '.join(sorted(builders))
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of: {supported}"
        ) from None

    # NOTE(review): torchvision >= 0.13 deprecates `pretrained` in favour of
    # `weights=None`; kept as-is since the installed version is not known here.
    model = build(pretrained=False)

    # Replace the classification head so its width follows `classes`
    # (previously hard-coded to 50 regardless of the argument).
    model.fc = nn.Linear(model.fc.in_features, classes)
    return model

if __name__ == '__main__':
    # Script entry point: build the default model as a quick smoke check.
    model = get_model(model_name='resnet-50', classes=50)